Side Adapter Network for Open-Vocabulary Semantic Segmentation
CVPR23 Highlight
https://arxiv.org/pdf/2302.12242.pdf
https://proceedings.neurips.cc/paper_files/paper/2017/file/e7b24b112a44fdd9ee93bdf998c6ca0e-Paper.pdf
https://scrapbox.io/files/64ef026389de23001b500e78.png
ImageNetのSegmentation結果。各画像に対し、inferenceのための語彙として、ImageNetのcategoryとcoco categoryを組み合わせ、annotated categoryのマスクのみを可視化したもの
Q. annotated category?
https://scrapbox.io/files/64ef024b67e655001b2cd331.png
赤い部分は学習における逆伝播計算。frozenさせたCLIPを分類器として振る舞わせ、side adapter networkが提案マスクとattention biasを生成し、より深い層のCLIPが提案マスクごとの分類のlogitを作成させる。
推論の間、提案マスクとlogitは組み合わせられ、行列演算(Matmul)を通じて最終的な予測を得る
https://scrapbox.io/files/64ef0365b063ae001b8da939.png
SAN側の構造。SANは入力画像をtokenizeし、Query tokenを開始時に付加する。さらに、CLIP modelのtransformer layersから得られる中間特徴量を合成している。query, visual featuresはMLPによってattention biasやmask proposalを差kすエイする。
https://scrapbox.io/files/64ef042af6f3c9001c403770.png
CLIPのattention biasを用いてマスク予測を行う。ここで、K=Qなので、Self-attentionであることに注意
SLSトークン集合(=shadow CLS tocken copies)が作成され、CLIPに適用される。SLSトークンはattention biasの下で更新される。右図は異なる種類のtokenがどのように相互作用するかを示している。
黒:KeyによってQueryは更新されない。
白:KeyによってQueryが更新される。
灰色:attention biasの影響下でKeyによってQueryが更新
Q. attention bias?
Q. K, Vの正体
そもそも、CLIPをOpen-vocabのSemantic Segmentationに使用するのが挑戦的
CLIPはimage-wiseでのcontrastive trainingを行っており、pixel-wiseではないからである。
For Segmentationのデータセットでfine-tuneを行った方法もあるが、性能が低いと指摘
Language-driven semantic segmentation. In ICLR, 2022.
two-stageでマスクの精製とラベル付けを行う手法もあるが、提案手法はend-to-endでの予測が可能
SANはCLIPの機能を活用可能な軽量なvision transformerで構成される
Q. Swinならだめなん?
この出力は、mask proposalsとattentino biasから成る。
attention biasはCLIPとのself-attentionに使用され、mask proposalのclass予測に使用される。
実際には、浅いCLIP層の特徴をSANに融合し、denseなCLIP層の残りにattention biasを用いている
実装
detectron2を用いて実装されているようだ。
timmからVision Transformerを引っ張ってきている。
timmの設定で、clsトークンを削除したり、normを消したりして初期化している
Q. clsトークンとは?
Q. normとは?
Query tokenをパラメータとして初期化、std=0.02で標準化している
多分実装だと、SAN部分とadapter部分を分けて実装している
SAN
CLIPで画像から特徴抽出を行う側
SANのmodelとしては、clipのopen-source implementationから取得している。
https://github.com/mlfoundations/open_clip
謎のモジュール、open_clip、clip_visual_extractor, clip_rec_head, ov_classifierがある。
Side-Adapter
ViTを拡張して、SAN側(Frozen CLIP側からの出力を受け取りながら特徴抽出を行う)
SANのキモの部分
code:python
def forward_features(
self, image: torch.Tensor, clip_features: Listtorch.Tensor
) -> List[Dictstr, torch.Tensor]:
x, (h, w) = self.vit_model.patch_embed(image)
L = x.shape1 # token length
pos_embed = self.vit_model.pos_embed
ori_h, ori_w = self.vit_model.patch_embed.grid_size
if pos_embed.shape1 != L:
pos_embed = (
F.interpolate(
pos_embed.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
size=h, w,
mode="bicubic",
align_corners=False,
)
.flatten(2)
.permute(0, 2, 1)
)
pos_embed = torch.cat(
[self.query_pos_embed.expand(pos_embed.shape0, -1, -1), pos_embed], dim=1
)
x = torch.cat(
[self.query_embed.expand(x.shape0, -1, -1), x],
dim=1,
) # B, Q+L, C
x = x + pos_embed
x = self.vit_model.norm_pre(x)
x = self.fuse(0, x, clip_features, (h, w))
outs = []
for i, blk in enumerate(self.vit_model.blocks, start=1):
x = blk(x)
x = self.fuse(i, x, clip_features, (h, w))
if i in self.deep_supervision_idxs:
outs.append(
{
"query": x:, :-L, ...,
"x": x:, -L:, ...
.permute(0, 2, 1)
.reshape(x.shape0, x.shape-1, h, w),
}
)
if i < len(self.vit_model.blocks):
x = x + pos_embed
return outs
def fuse(
self,
block_idx: int,
x: torch.Tensor,
clip_features: Listtorch.Tensor,
spatial_shape: Tupleint, int,
) -> torch.Tensor:
if block_idx in self.fusion_map:
src_idx = self.fusion_mapblock_idx
L = spatial_shape0 * spatial_shape1
x = torch.cat(
[
x:, :-L, ...,
self.fusion_layersf"layer_{block_idx}"(
x:, -L:, ..., clip_featuressrc_idx, spatial_shape
),
],
dim=1,
)
log_first_n(
logging.INFO,
f"fuse clip {src_idx} to {block_idx}",
len(self.fusion_map),
)
return x
forward関数の時点で、これをvit_blockぶん呼び出して、構造を作っている。
正直、forward_features部分はほぼVision Transformerと同じなのであまり重要視しなくても良い
fuse関数がキモ、これさえ理解&implementさえできれば、detectron2がなくても再現が可能である。
各変数について理解していくしかない
self.fusion_map
これはclassのinitで撮っているmap、DIct(int, int)らしい
["0->0", "3->1", "6->2", "9->3"]のmapとやらから、{int(j): int(i) for i, j in [x.split("->") for x in fusion_map]}として取得しているっぽい?何を指しているかはpaper見ないとよくわからないか、、、?
Table 4, different feature fusion strategiesを表していると考えられる。
https://scrapbox.io/files/64f387a41d8cad001b952ba0.png
deeper feature tend to be more semanticである
CLS tokenは除いて、visual tokensのみをSANにfuseしたい。
特徴次元と特徴の数がCLIPとSANで異なるから、visual tokenをfeature mapに1 \times 1 convolutionを通じて変換し、channel次元とfeature map sizeを揃える。そして、対応するfeature mapをelement-wise additionで加算する。
ViT-B/16CLIPは12層、SANは8層のモデルであり、このmapが上のやつに対応しているっぽい!!
なお、直感的なデザインであるため、他の演算は将来研究としている。
だから、クラスのinitで初期化する際にCLS tokenの削除を行っている
L
spatial_shape[0]*spatial_shape[1]、何?
h, wらしい、全特徴量 = 系列長にしたいんちゃうかな
x
xとfusion_layersから出てきたやつを系列長方向にconcatしているっぽい?
self.fusion_layers
nn.ModuleListらしい。なんかinitで撮っている。
fusion_typeとやらは"add"らしい
nn.ModuleDict({f"layer_{tgt_idx}": build_fusion_layer(fusion_type, input_shape[src_idx].channels, vit_num_features) for tgt_idx, src_idx in x2sidemap.items()})
input_shapeとやらは、build_side_adapter_networkの引数であり、呼ばれているのはclip_visual_extracotorのoutput_shapeである。
clip_visual_extractorは、FeatureExtractorのoutput_shapeになる。
code:python
class FeatureExtractor(nn.Module):
def __init__(
self,
visual_encoder: VisionTransformer,
last_layer_idx: int = -1,
frozen_exclude=[],
):
super().__init__()
self.output_tokens = visual_encoder.output_tokens
self.image_size = visual_encoder.image_size
self.patch_size = visual_encoder.patch_size
self.grid_size = visual_encoder.grid_size
self.num_features = visual_encoder.ln_pre.normalized_shape0
self.input_patchnorm = visual_encoder.input_patchnorm
self.patchnorm_pre_ln = visual_encoder.patchnorm_pre_ln
self.conv1 = visual_encoder.conv1
# class embeddings and positional embeddings
self.class_embedding = visual_encoder.class_embedding
self.positional_embedding = visual_encoder.positional_embedding
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
self.patch_dropout = visual_encoder.patch_dropout
self.ln_pre = visual_encoder.ln_pre
if last_layer_idx == -1:
self.resblocks = visual_encoder.transformer.resblocks
self.last_output_idx = len(self.resblocks) + 1
else:
self.resblocks = visual_encoder.transformer.resblocks:last_layer_idx
self.last_output_idx = last_layer_idx + 1
#
self.frozen_exclude = frozen_exclude
self._freeze(self.frozen_exclude)
def forward(self, x: torch.Tensor):
if self.input_patchnorm:
raise NotImplementedError("input_patchnorm is not implemented yet.")
else:
x = self.conv1(x) # shape = width, grid, grid
_, _, h, w = x.shape
x = x.reshape(x.shape0, x.shape1, -1) # shape = width, grid ** 2
x = x.permute(0, 2, 1) # shape = grid ** 2, width
# class embeddings and positional embeddings
x = torch.cat(
[
self.class_embedding.to(x.dtype)
+ torch.zeros(
x.shape0, 1, x.shape-1, dtype=x.dtype, device=x.device
),
x,
],
dim=1,
) # shape = grid ** 2 + 1, width
pos_embed = self.positional_embedding.to(x.dtype)
pos_embed = resize_pos_embed2d(pos_embedNone, ..., self.grid_size, (h, w))0
x = x + pos_embed
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
x = self.patch_dropout(x)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
outputs = ClipOutput(spacial_shape=(h, w))
outputs.save(0, x)
for i, resblock in enumerate(self.resblocks, start=1):
x = resblock(x)
outputs.save(i, x)
return outputs
def _freeze(self, frozen_exclude):
if "all" in frozen_exclude:
return
for name, param in self.named_parameters():
if not any(exclude in name for exclude in frozen_exclude):
param.requires_grad = False
@property
def output_shapes(self):
return {
i: ShapeSpec(channels=self.num_features)
for i in range(self.last_output_idx)
}
@property
def size_divisibility(self):
return self.patch_size0
src_idxとtgt_idxについて
fusion_mapのあれに対応している。ここで、CLIP側がtarget, SAN側がsourceのidxになっていることに注意する。
input_shapeとやらがよくわからんので、LB回し終わったら一回確認、その後ada
結局、SAN側の第src_idx層の入力になるchannel数になる
add fusion layerとやらは以下
code:python
def build_fusion_layer(fusion_type: str, in_channels: int, out_channels: int):
if fusion_type == "add":
return AddFusion(in_channels, out_channels)
else:
raise ValueError("Unknown fusion type: {}".format(fusion_type))
class AddFusion(CNNBlockBase):
def __init__(self, in_channels, out_channels):
super().__init__(in_channels, out_channels, 1)
self.input_proj = nn.Sequential(
LayerNorm(in_channels),
Conv2d(
in_channels,
out_channels,
kernel_size=1,
),
)
weight_init.c2_xavier_fill(self.input_proj-1)
def forward(self, x: torch.Tensor, y: torch.Tensor, spatial_shape: tuple):
# x: N,L,C y: N,C,H,W
y = (
F.interpolate(
self.input_proj(y.contiguous()),
size=spatial_shape,
mode="bilinear",
align_corners=False,
)
.permute(0, 2, 3, 1)
.reshape(x.shape)
)
x = x + y
return x
CNNBlockBaseはDetectron2の負債
code:python
class CNNBlockBase(nn.Module):
"""
A CNN block is assumed to have input channels, output channels and a stride.
The input and output of forward() method must be NCHW tensors.
The method can perform arbitrary computation but must match the given
channels and stride specification.
Attribute:
in_channels (int):
out_channels (int):
stride (int):
"""
def __init__(self, in_channels, out_channels, stride):
"""
The __init__ method of any subclass should also contain these arguments.
Args:
in_channels (int):
out_channels (int):
stride (int):
"""
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
def freeze(self):
"""
Make this block not trainable.
This method sets all parameters to requires_grad=False,
and convert all BatchNorm layers to FrozenBatchNorm
Returns:
the block itself
"""
for p in self.parameters():
p.requires_grad = False
FrozenBatchNorm2d.convert_frozen_batchnorm(self)
return self
これを継承させるだけの簡単なやつ